
import numpy as np
import tqdm

from src.imports.import_profile import (
    import_distances, import_coordinates, import_radius,
)

from src.imports.import_clustering_sets import import_clustering_sets
import matplotlib.pyplot as plt


from src.imports.import_profile import import_profile


def assign_candidate_colors(clusters, palette, num_candidates):
    """Return one color per candidate index, chosen by cluster membership.

    Args:
        clusters: sequence of iterables; ``clusters[k]`` holds the candidate
            indices that belong to cluster ``k``.
        palette: sequence of colors; ``palette[k]`` is the color of cluster ``k``.
        num_candidates: number of candidates; indices ``0..num_candidates-1``
            are looked up.

    Returns:
        A list of length ``num_candidates`` where entry ``i`` is the color of
        the first cluster (lowest ``k``) that contains candidate ``i``.

    Raises:
        ValueError: if some candidate index is not a member of any cluster.
            Without this check the color list would be shorter than the
            coordinate arrays and plotting would fail with a less helpful
            error.
    """
    color_of = {}
    for k, members in enumerate(clusters):
        for candidate in members:
            # setdefault keeps the color of the first cluster that lists the
            # candidate, should it appear in more than one.
            color_of.setdefault(candidate, palette[k])

    missing = [i for i in range(num_candidates) if i not in color_of]
    if missing:
        raise ValueError(
            f"candidates not assigned to any cluster: {missing}"
        )
    return [color_of[i] for i in range(num_candidates)]


if __name__ == "__main__":
    # Renders, for every base distance and every sampled 1D Euclidean
    # instance, a picture of the candidates colored by their cluster.
    import os

    num_tests = 100
    num_voters = 1000
    num_candidates = 100
    lower_radius = 0.015
    upper_radius = 0.15

    num_clusters = 5

    method = 'hierarchical_minavg'

    base_distances = ['norm_hamming', 'geom_hamming', 'rms_hamming',
                      'jaccard', 'geom_jaccard', 'rms_jaccard']

    # Per-distance palettes: entry k is the color used for cluster k.
    COLORS = {
        'norm_hamming': ['red', 'green', 'blue', 'purple', 'orange'],
        'geom_hamming': ['purple', 'red', 'blue', 'green', 'orange'],
        'rms_hamming': ['green', 'red', 'blue', 'purple', 'orange'],

        'jaccard': ['purple', 'red', 'green', 'orange', 'blue'],
        'geom_jaccard': ['purple', 'blue', 'orange', 'green', 'red'],
        'rms_jaccard': ['purple', 'red', 'blue', 'green', 'orange'],
    }

    coordinates_dir = 'data/sampled/euclidean_1d/coordinates'
    image_dir = f'images/euclidean_1d/k{num_clusters}'
    # savefig does not create missing directories.
    os.makedirs(image_dir, exist_ok=True)

    for base_distance in base_distances:
        # IMPORT CLUSTERS
        c_sets_path = f"output/sampled/euclidean_1d/{method}_{base_distance}_c_sets_{num_clusters}_{num_voters}_{num_candidates}_{lower_radius}_{upper_radius}.txt"
        c_sets = import_clustering_sets(c_sets_path)
        palette = COLORS[base_distance]

        for t in tqdm.tqdm(range(num_tests)):
            instance = f'{num_voters}_{num_candidates}_{t}_{lower_radius}_{upper_radius}'

            # IMPORT COORDINATES
            c_points_path = f'{coordinates_dir}/c_points_{instance}.csv'
            c_coordinates = import_coordinates(c_points_path)
            c_radius_path = f'{coordinates_dir}/c_radius_{instance}.csv'
            c_radius = import_radius(c_radius_path)

            # c_sets is keyed by test index, then cluster index, both as strings.
            clusters = [list(c_sets[str(t)][str(k)]) for k in range(num_clusters)]
            colors = assign_candidate_colors(clusters, palette, num_candidates)

            # PLOT
            fig = plt.figure(figsize=(10, 3))

            # One disc per candidate on the x axis; marker area grows with the
            # square of the candidate's radius.
            plt.scatter(c_coordinates[:, 0],
                        np.zeros(num_candidates),
                        c=colors,
                        s=(c_radius*100)**2,
                        alpha=0.5
                        )

            # Vertical segment spanning +/- radius at each candidate's position.
            for i in range(num_candidates):
                plt.plot([c_coordinates[i, 0], c_coordinates[i, 0]],
                         [-c_radius[i], c_radius[i]],
                         color=colors[i],
                         alpha=0.5,
                         linewidth=2
                         )

            plt.axis('off')
            plt.savefig(f'{image_dir}/{method}_{base_distance}_{instance}.png', dpi=300, bbox_inches='tight')
            # Close (not just clear) the figure: a new one is created every
            # iteration, and unclosed figures stay alive inside pyplot.
            plt.close(fig)

